'''
Author: SlytherinGe
LastEditTime: 2021-06-23 17:31:45
'''
import mmcv
import matplotlib.pyplot as plt
import numpy as np
import os

def load_feature_maps(feature_file):
    '''
    Read back a previously dumped collection of feature maps.

    feature_file: path of the file that stores all feature maps; it is handed
        unchanged to mmcv.load, which decides how to deserialize it.

    Returns whatever object mmcv.load produces for that file.

    NOTE(review): for .pkl dumps this presumably goes through pickle, so only
    load files from trusted sources.
    '''
    return mmcv.load(feature_file)


def visualize_feature_maps(feature_maps, vis_feat_lvl, save_folder=None,
                           grid=(4, 5)):
    '''
    Plot every channel of the selected feature levels, one page of
    ``grid[0] * grid[1]`` channels per figure.

    feature_maps: a collection indexable by integer level (a list, or a dict
        keyed 0..N-1). Each entry must provide ``.cpu()`` and convert to an
        array of shape (1, C, H, W); presumably torch tensors, so confirm with
        the producer of the dump.
    vis_feat_lvl: a list that contains the feature levels to be displayed,
        eg. [0, 3], means that the feature maps whose indexes are 0 and 3 will
        be displayed
    save_folder: if given, every full page is written to
        ``<save_folder>/<lvl>-<page>.jpg`` (page counted from 1) and a trailing
        partial page to ``<save_folder>/<lvl>-last.jpg``. The folder (and any
        missing parents) is created if needed. If None, each page is shown
        with ``plt.show()`` instead.
    grid: (rows, cols) of subplots per page; defaults to the original 4x5.

    Raises AssertionError if a requested level is outside [0, len(feature_maps)).
    '''
    rows, cols = grid
    per_page = rows * cols

    if save_folder is not None:
        # makedirs also creates missing parents; exist_ok avoids the
        # check-then-create race of os.path.exists + os.mkdir.
        os.makedirs(save_folder, exist_ok=True)

    feature_levels = len(feature_maps)
    for lvl in vis_feat_lvl:
        assert 0 <= lvl < feature_levels, \
            'level {} out of range [0, {})'.format(lvl, feature_levels)
        # squeeze(0) drops the leading batch axis -> (C, H, W)
        feat = np.array(feature_maps[lvl].cpu()).squeeze(0)
        num_ch = feat.shape[0]
        for start in range(0, num_ch, per_page):
            stop = min(start + per_page, num_ch)
            # A fresh figure per page: reusing one figure made later pages draw
            # over earlier axes and stack extra colorbars.
            fig = plt.figure()
            for i in range(start, stop):
                plt.subplot(rows, cols, i - start + 1)
                plt.imshow(feat[i, :, :])
                plt.colorbar(fraction=0.05, pad=0.05)
                plt.axis('off')
                plt.title('ch {} size {}'.format(i, feat.shape[1:]))
            if save_folder is not None:
                if stop - start == per_page:
                    # full page: 1-based page number, same names as before
                    name = "{}-{}.jpg".format(lvl, stop // per_page)
                else:
                    # only a trailing partial page is "last"; no duplicate of
                    # the final full page is written any more
                    name = "{}-last.jpg".format(lvl)
                fig.savefig(os.path.join(save_folder, name))
                # release the figure, otherwise every saved page stays in memory
                plt.close(fig)
            else:
                plt.show()

def visualize_feature_maps_diff_scale(feature_maps, vis_feat_lvl, shape):
    '''
    Plot the channels of several feature levels on shared pages.

    Channels from all requested levels are laid out consecutively, so maps of
    different spatial sizes can appear side by side. Each page holds
    ``shape[0] * shape[1]`` subplots and is displayed with ``plt.show()``.

    feature_maps: a collection indexable by integer level (a list, or a dict
        keyed 0..N-1). Each entry must provide ``.cpu()`` and convert to an
        array of shape (1, C, H, W); presumably torch tensors, so confirm with
        the producer of the dump.
    vis_feat_lvl: a list that contains the feature levels to be displayed,
        eg. [0, 3], means that the feature maps whose indexes are 0 and 3 will
        be displayed (in that order)
    shape: a (rows, cols) tuple for subplot

    Raises AssertionError if a requested level is outside [0, len(feature_maps)).
    '''
    row, col = shape
    size = row * col
    feature_levels = len(feature_maps)
    cnt = 0  # global subplot counter across all levels
    for lvl in vis_feat_lvl:
        assert 0 <= lvl < feature_levels, \
            'level {} out of range [0, {})'.format(lvl, feature_levels)
        # squeeze(0) drops the leading batch axis -> (C, H, W)
        feat = np.array(feature_maps[lvl].cpu()).squeeze(0)
        for i in range(feat.shape[0]):
            if cnt % size == 0:
                # Start every page on an explicit new figure instead of relying
                # on plt.show() having closed the previous one (it does not
                # with interactive backends, which made pages overlap).
                plt.figure()
            plt.subplot(row, col, cnt % size + 1)
            plt.imshow(feat[i, :, :])
            plt.colorbar(fraction=0.05, pad=0.05)
            plt.axis('off')
            plt.title('feat {} size {}'.format(cnt, feat.shape[1:]))
            cnt += 1
            if cnt % size == 0:
                plt.show()
    # show a trailing partially filled page, if any
    if cnt % size != 0:
        plt.show()



if __name__ == '__main__':

    PAOI_SAVE_FOLDER = '/media/gejunyao/Disk/Gejunyao/exp_results/visualization/middle_part/paoi_feats/'
    # number of leading feature maps kept from each dump file
    PAOI_NUM = 5

    # os.listdir returns entries in arbitrary order; sort so the level indices
    # passed to the visualizer refer to the same files on every run.
    paois = sorted(os.listdir(PAOI_SAVE_FOLDER))

    paoi_list = []

    for paoi in paois:
        paoi_list.extend(list(load_feature_maps(os.path.join(PAOI_SAVE_FOLDER, paoi)))[:PAOI_NUM])

    # visualize_feature_maps_diff_scale(paoi_list, [0,3,6,9,1,4,7,10,2,5,8,11], (3,4))
    # levels 0..19 on a 4x5 grid; assumes at least 20 maps were collected
    visualize_feature_maps_diff_scale(paoi_list, list(range(20)), (4, 5))
    # feature_map = load_feature_maps('/media/gejunyao/Disk/Gejunyao/exp_results/visualization/middle_part/paoi_feats.pkl')

    # attention_map = load_feature_maps('/media/gejunyao/Disk/Gejunyao/exp_results/visualization/middle_part/hrfpn_data.pkl')

    # visualize_feature_maps_diff_scale(feature_map, [0,1,2], (1,3))
